"""
Author: Benny
Date: Nov 2019
"""
from data_utils.ModelNetDataLoader import ModelNetDataLoader
import argparse
import numpy as np
import os
import torch
import datetime
import logging
from pathlib import Path
from tqdm import tqdm
import sys
import provider
import importlib
import shutil

# Directory containing this script. ROOT_DIR is an alias for the same path, and
# the ``models`` sub-directory is appended to ``sys.path`` so that model modules
# can be imported by their bare module name.
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, 'models'))


def parse_args():
    """Parse the command-line options of the classification training script.

    Options are registered in a fixed order from a table of
    ``(flag, add_argument keyword arguments)`` pairs, then ``sys.argv`` is
    parsed.

    Returns:
        argparse.Namespace: with attributes ``batch_size``, ``model``,
        ``epoch``, ``learning_rate``, ``gpu``, ``num_point``, ``optimizer``,
        ``log_dir``, ``decay_rate`` and ``normal``.
    """
    option_table = (
        ('--batch_size', dict(type=int, default=24, help='batch size in training [default: 24]')),
        ('--model', dict(default='pointnet_cls', help='model name [default: pointnet_cls]')),
        ('--epoch', dict(default=200, type=int, help='number of epoch in training [default: 200]')),
        ('--learning_rate', dict(default=0.001, type=float,
                                 help='learning rate in training [default: 0.001]')),
        ('--gpu', dict(type=str, default='0', help='specify gpu device [default: 0]')),
        ('--num_point', dict(type=int, default=1024, help='Point Number [default: 1024]')),
        ('--optimizer', dict(type=str, default='Adam', help='optimizer for training [default: Adam]')),
        ('--log_dir', dict(type=str, default=None, help='experiment root')),
        ('--decay_rate', dict(type=float, default=1e-4, help='decay rate [default: 1e-4]')),
        ('--normal', dict(action='store_true', default=False,
                          help='Whether to use normal information [default: False]')),
    )

    cli = argparse.ArgumentParser('PointNet')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def _model_device(model):
    """Return the device of ``model``'s first parameter (CUDA if it has none)."""
    first_param = next(iter(model.parameters()), None)
    return first_param.device if first_param is not None else torch.device('cuda')


def _rotation_matrix(theta, gamma, device):
    """Build the 3x3 rotation: ``theta`` in the xy-plane, then ``gamma`` in the yz-plane."""
    cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))
    cos_g, sin_g = float(np.cos(gamma)), float(np.sin(gamma))
    rot_xy = torch.tensor([[cos_t, -sin_t, 0.0],
                           [sin_t, cos_t, 0.0],
                           [0.0, 0.0, 1.0]], device=device, dtype=torch.float)
    rot_yz = torch.tensor([[1.0, 0.0, 0.0],
                           [0.0, cos_g, -sin_g],
                           [0.0, sin_g, cos_g]], device=device, dtype=torch.float)
    return rot_yz @ rot_xy


def _rotate_points(points, rotation):
    """Return a rotated copy of a channels-first ``(batch, channels, num_points)`` tensor."""
    rotation = rotation.to(device=points.device, dtype=points.dtype)
    rotated = points.clone()
    # The (3, 3) matrix broadcasts over the batch dimension: (3, 3) @ (B, 3, N) -> (B, 3, N).
    rotated[:, 0:3, :] = torch.matmul(rotation, points[:, 0:3, :])
    if points.shape[1] == 6:
        # Channels 3..5 hold the normals; they rotate together with the coordinates.
        rotated[:, 3:6, :] = torch.matmul(rotation, points[:, 3:6, :])
    return rotated


def _mean_class_accuracy(stats):
    """Average per-class accuracy over the classes that occurred at least once."""
    seen = stats[:, 1] > 0
    if not seen.any():
        return float('nan')
    return float(np.mean(stats[seen, 0] / stats[seen, 1]))


def test(model, loader, num_class=40, rotate=False):
    """Evaluate ``model`` on ``loader``, optionally under a grid of rotations.

    Every batch is classified once per rotation. A sample's *average-case*
    score is the fraction of rotations for which it was classified correctly;
    its *worst-case* score is 1 only if it was correct under every rotation.
    Instance accuracies are the mean of the per-batch accuracies, and class
    accuracies are the mean over classes of the per-batch class accuracies
    (classes that never occur in ``loader`` are excluded from that mean).

    Args:
        model: classifier; called as ``pred, _ = model(points)`` with
            ``points`` laid out ``(batch, channels, num_points)``.
        loader: iterable of ``(points, target)`` batches. ``points`` is
            transposed with ``transpose(2, 1)`` before being fed to the model
            and ``target[:, 0]`` is used as the label vector.
        num_class: number of classes used to size the per-class statistics.
        rotate: if True, evaluate on a 4 x 4 grid of xy-/yz-plane rotations
            (angles 0, pi/2, pi, 3*pi/2); otherwise evaluate unrotated input.

    Returns:
        tuple: ``(avg_instance_acc, avg_class_acc, worst_instance_acc,
        worst_class_acc)`` as floats (NaN when ``loader`` yields no batches).
    """
    classifier = model.eval()
    device = _model_device(classifier)

    if rotate:
        # endpoint=False: 0 and 2*pi are the same rotation and must not be counted twice.
        angles = np.linspace(0, 2 * np.pi, num=4, endpoint=False)
        rotations = [_rotation_matrix(theta, gamma, device) for theta in angles for gamma in angles]
    else:
        rotations = [None]
    num_trials = len(rotations)

    # Per class: column 0 sums the per-batch accuracies, column 1 counts the batches.
    class_acc_avg = np.zeros((num_class, 2))
    class_acc_worst = np.zeros((num_class, 2))
    mean_correct_avg = []
    mean_correct_worst = []

    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            target = target[:, 0].to(device)
            points = points.transpose(2, 1).to(device)

            worst_correct = torch.ones_like(target, dtype=torch.bool)
            avg_correct = torch.zeros_like(target, dtype=torch.float)
            for rotation in rotations:
                # Always rotate the pristine batch so rotations do not accumulate.
                inputs = points if rotation is None else _rotate_points(points, rotation)
                pred, _ = classifier(inputs)
                hits = pred.argmax(dim=1) == target
                worst_correct &= hits
                avg_correct += hits.float()
            avg_correct /= num_trials

            for cls in np.unique(target.cpu().numpy()):
                members = target == int(cls)
                count = float(members.sum().item())
                class_acc_avg[cls, 0] += avg_correct[members].sum().item() / count
                class_acc_avg[cls, 1] += 1
                class_acc_worst[cls, 0] += worst_correct[members].sum().item() / count
                class_acc_worst[cls, 1] += 1

            batch_size = float(points.size()[0])
            mean_correct_avg.append(avg_correct.sum().item() / batch_size)
            mean_correct_worst.append(worst_correct.sum().item() / batch_size)

    avg_instance_acc = float(np.mean(mean_correct_avg)) if mean_correct_avg else float('nan')
    worst_instance_acc = float(np.mean(mean_correct_worst)) if mean_correct_worst else float('nan')
    avg_class_acc = _mean_class_accuracy(class_acc_avg)
    worst_class_acc = _mean_class_accuracy(class_acc_worst)

    return avg_instance_acc, avg_class_acc, worst_instance_acc, worst_class_acc


def main(args):
    """Train a classifier on ModelNet40 and report its robustness to rotations.

    Creates ``./log/classification/<run>/{checkpoints,logs}``, trains the model
    named by ``args.model``, evaluates on the test split after every epoch,
    saves the best checkpoint (by test instance accuracy) and finally
    evaluates the last-epoch model under a grid of rotations.

    Args:
        args: namespace providing ``gpu``, ``log_dir``, ``model``,
            ``num_point``, ``normal``, ``batch_size``, ``optimizer``,
            ``learning_rate``, ``decay_rate`` and ``epoch``.
    """
    def log_string(message):
        logger.info(message)
        print(message)

    # HYPER PARAMETER
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # CREATE DIR
    if args.log_dir is None:
        run_name = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    else:
        run_name = args.log_dir
    experiment_dir = Path('./log/').joinpath('classification').joinpath(run_name)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    log_dir = experiment_dir.joinpath('logs/')
    # parents=True also creates ./log/classification and supports a nested --log_dir.
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    log_dir.mkdir(parents=True, exist_ok=True)

    # LOG
    # The caller's ``args`` are used as-is; re-parsing the command line here
    # would silently override them after the directories were already created.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # DATA LOADING
    log_string('Load dataset ...')
    DATA_PATH = 'data/modelnet40_normal_resampled/'

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train',
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test',
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(
        TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(
        TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # MODEL LOADING
    num_class = 40
    MODEL = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    classifier = MODEL.get_model(num_class, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()

    start_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    checkpoint_path = checkpoints_dir.joinpath('best_model.pth')
    if checkpoint_path.is_file():
        try:
            checkpoint = torch.load(str(checkpoint_path))
            resume_epoch = checkpoint['epoch']
            # Restore the best scores so a resumed run cannot overwrite the
            # saved best model with a worse one.
            resume_instance_acc = float(checkpoint.get('instance_acc', 0.0))
            resume_class_acc = float(checkpoint.get('class_acc', 0.0))
            classifier.load_state_dict(checkpoint['model_state_dict'])
        except Exception as err:
            # Best effort: an unreadable checkpoint falls back to a fresh run,
            # but the reason is logged instead of being silently swallowed.
            log_string('Could not load checkpoint %s (%s), starting training from scratch...'
                       % (checkpoint_path, err))
        else:
            start_epoch = resume_epoch
            best_instance_acc = resume_instance_acc
            best_class_acc = resume_class_acc
            log_string('Use pretrain model')
    else:
        log_string('No existing model, starting training from scratch...')

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # Any other value selects SGD with a fixed learning rate of 0.01.
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    # NOTE: optimizer and scheduler state are not restored when resuming, so
    # the learning-rate schedule restarts from its initial value.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0

    # TRAINING
    log_string('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        # Reset every epoch so the logged train accuracy describes this epoch
        # only rather than a running mean over all epochs so far.
        mean_correct = []
        # test() leaves the model in eval mode; switch back before training.
        classifier = classifier.train()
        for points, target in tqdm(trainDataLoader, total=len(trainDataLoader), smoothing=0.9):
            # Augmentation is done on NumPy arrays by the ``provider`` helpers.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()

        # Step the schedule after the epoch's optimizer updates; stepping it
        # first would skip the first value of the learning-rate schedule.
        scheduler.step()

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc, _, _ = test(classifier.eval(), testDataLoader)

        is_best = instance_acc >= best_instance_acc
        if is_best:
            best_instance_acc = instance_acc
        if class_acc >= best_class_acc:
            best_class_acc = class_acc
        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
        log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

        if is_best:
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch + 1,
                # Plain floats keep the checkpoint free of NumPy scalar objects.
                'instance_acc': float(instance_acc),
                'class_acc': float(class_acc),
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
        global_epoch += 1

    log_string('End of training... Testing rotations now')
    # Evaluation only: no autograd graph is needed for the rotation sweep.
    with torch.no_grad():
        instance_acc, class_acc, worst_case_instance_acc, worst_case_class_acc = test(
            classifier, testDataLoader, rotate=True)
    log_string(f'Worst-case Test Instance Accuracy: {worst_case_instance_acc} - '
               f'Average-case Test Instance Accuracy: {instance_acc}')
    log_string(f'Worst-case Test Class Accuracy: {worst_case_class_acc} - '
               f'Average-case Test Class Accuracy: {class_acc}')


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(parse_args())
